//! Burrows–Wheeler transform (BWT) built from a suffix array, and its inverse via LF mapping.
//!
//! References:
//! - https://en.wikipedia.org/wiki/Burrows%E2%80%93Wheeler_transform
//! - https://en.wikipedia.org/wiki/FM-index
//! - https://www.cs.jhu.edu/~langmea/resources/lecture_notes/bwt_and_fm_index.pdf
//! - https://www.cs.cmu.edu/~ckingsf/bioinfo-lectures/bwt.pdf

use std::{io::{BufReader, BufWriter, Write, Error as IOError, Read, Seek, SeekFrom}, time::Instant};
use crate::common::libsais::{OperationError, libsais_u8_i32};

/// Collapses the outcome of an I/O call into this module's error type.
///
/// The success payload is discarded; any I/O error becomes `OperationError(3)`.
#[inline]
fn handle_ior_nocheck<T>(ior: Result<T, IOError>) -> Result<(), OperationError> {
    ior.map(|_payload| ()).map_err(|_io_err| OperationError(3))
}

/// Validates the byte count reported by a `read`/`write` call.
///
/// * `ior` - result of the I/O call, carrying the number of bytes processed.
/// * `len` - number of bytes the caller needed to be processed.
///
/// Returns `Ok(())` when at least `len` bytes were processed. A smaller count is
/// reported on stdout and yields `OperationError(4)`; an I/O error yields
/// `OperationError(3)`.
#[inline]
fn handle_ior_check_len(ior: Result<usize, IOError>, len: usize) -> Result<(), OperationError> {
    match ior {
        Err(_) => Err(OperationError(3)),
        Ok(done) if done >= len => Ok(()),
        Ok(done) => {
            println!("IO Op demands {}, but only get {} done!", len, done);
            Err(OperationError(4))
        }
    }
}

/// Perform BWT using the suffix array approach, generate auxiliary data structures
/// for efficient L (last column) to F (first column) mapping needed by the decoder.
///
/// Exactly `len` bytes are consumed from `reader` (short reads are retried until the
/// buffer is full or the source is exhausted). Three streams are produced:
///
/// * `writer_out` - the transformed bytes (`len` of them; the end marker itself is
///   not written), followed by the end marker's row index as a native-endian `i32`.
/// * `writer_c` - one native-endian `i32` per entry of the frequency table returned
///   by `libsais_u8_i32`: a running sum of the preceding frequencies, starting at 1.
/// * `writer_occ` - one native-endian `i32` per transformed byte: how many times the
///   same byte value was emitted before it.
///
/// Returns `Ok(0)` without touching any stream when `len <= 0`, otherwise `Ok(len)`.
///
/// # Errors
///
/// `OperationError(3)` on an I/O error, `OperationError(4)` when the input ends before
/// `len` bytes were read or a write is cut short; errors from `libsais_u8_i32` are
/// propagated unchanged.
///
/// # Panics
///
/// Panics if `libsais_u8_i32` returns no frequency table or a negative suffix index.
pub fn transform<R: Read, W: Write>(
    len: i32,
    reader: &mut R,
    writer_out: &mut W,
    writer_c: &mut W,
    writer_occ: &mut W,
) -> Result<i32, OperationError> {
    // Fills `buf` completely, calling `read` as often as needed: a single `read` is
    // allowed to return fewer bytes than requested even when more data will follow.
    fn read_fully<S: Read>(src: &mut S, buf: &mut [u8]) -> Result<(), OperationError> {
        let mut filled = 0usize;
        while filled < buf.len() {
            match src.read(&mut buf[filled..]) {
                Ok(0) => break, // source exhausted
                Ok(n) => filled += n,
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => return handle_ior_check_len(Err(e), buf.len()),
            }
        }
        // Same diagnostics and error codes as every other length-checked I/O call.
        handle_ior_check_len(Ok(filled), buf.len())
    }

    if len <= 0 {
        return Ok(0);
    }

    let mut input_buf: Vec<u8> = vec![0u8; len as usize];
    read_fully(reader, &mut input_buf)?;

    let t0 = Instant::now();

    // NOTE(review): the `true` flag presumably asks libsais for the symbol frequency
    // table alongside the suffix array — confirm against the binding.
    let (freq_opt, sa) = libsais_u8_i32(&input_buf, 0, true)?;

    let d1 = t0.elapsed();
    println!("[bwt] SA computing takes: {:?}", d1);

    let mut c = freq_opt.unwrap();

    // C array: running sum of the frequencies, offset by one for the end marker row
    // that precedes every real symbol in the first column.
    let mut acc_sum: i32 = 1;
    for x in c.iter() {
        handle_ior_check_len(writer_c.write(&acc_sum.to_ne_bytes()), 4)?;
        acc_sum += x;
    }

    let d2 = t0.elapsed();
    println!("[bwt] Writing C array takes: {:?}", d2 - d1);

    // The frequency table is recycled as the per-symbol "seen so far" counters.
    c.fill(0i32);
    let mut endmark_pos = 0i32;
    let mut onebyte_buf = [0u8; 1];

    // Emits input[idx] to the output stream together with its occurrence rank.
    let mut write_one_converted_and_occ = |idx: i32| -> Result<(), OperationError> {
        let b = input_buf[idx as usize];
        onebyte_buf[0] = b;
        handle_ior_check_len(writer_out.write(&onebyte_buf), 1)?;
        handle_ior_check_len(writer_occ.write(&c[b as usize].to_ne_bytes()), 4)?;
        c[b as usize] += 1;
        Ok(())
    };

    // Row 0 belongs to the end marker suffix, which is preceded by the last input byte.
    write_one_converted_and_occ(len - 1)?;

    for (i, suffix_idx) in sa.into_iter().enumerate() {
        if suffix_idx > 0 {
            write_one_converted_and_occ(suffix_idx - 1)?;
        } else if suffix_idx == 0 {
            // The whole text is preceded by the end marker: remember the row instead
            // of writing a byte (+1 because of the extra row 0 emitted above).
            endmark_pos = (i + 1) as i32;
            println!("[bwt] endmark_pos: {}", endmark_pos);
        } else {
            panic!("Should not contain negative number here!");
        }
    }

    handle_ior_check_len(writer_out.write(&endmark_pos.to_ne_bytes()), 4)?;

    // Measure against the previous checkpoint so the C-array phase is not counted twice.
    let d3 = t0.elapsed();
    println!("[bwt] Writing Occ array & out takes: {:?}", d3 - d2);
    Ok(len)
}

/// Reverse BWT using LF mapping.
///
/// The LF walk recovers the original content back to front. With `in_order == false`
/// the bytes are written sequentially as they are recovered, i.e. the output is the
/// original content reversed. With `in_order == true` the writer is first moved to
/// absolute offset `len - 1` and stepped backwards after every byte, so offsets
/// `0..len` of `writer_out` end up holding the content in its original order.
///
/// Inputs (all integers native-endian `i32`):
///
/// * `reader` - `len` transformed bytes followed by the end marker row index.
/// * `reader_c` - 256 entries of the C array.
/// * `reader_occ` - `len` entries of the Occ array.
///
/// Short reads on any of the readers are retried until the requested amount has
/// arrived or the source is exhausted.
///
/// Returns `Ok(0)` without touching any stream when `len <= 0`, otherwise `Ok(len)`.
///
/// # Errors
///
/// `OperationError(3)` on an I/O or seek error, `OperationError(4)` when an input ends
/// early or a write is cut short.
///
/// # Panics
///
/// Table lookups are bounds-checked, so inconsistent C/Occ data can panic with an
/// out-of-range index.
pub fn reverse_transform<R: Read, W: Write + Seek>(
    len: i32,
    reader: &mut R,
    reader_c: &mut R,
    reader_occ: &mut R,
    writer_out: &mut W,
    in_order: bool
) -> Result<i32, OperationError> {
    // Fills `buf` completely, calling `read` as often as needed: a single `read` is
    // allowed to return fewer bytes than requested even when more data will follow.
    fn read_fully<S: Read>(src: &mut S, buf: &mut [u8]) -> Result<(), OperationError> {
        let mut filled = 0usize;
        while filled < buf.len() {
            match src.read(&mut buf[filled..]) {
                Ok(0) => break, // source exhausted
                Ok(n) => filled += n,
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => return handle_ior_check_len(Err(e), buf.len()),
            }
        }
        // Same diagnostics and error codes as every other length-checked I/O call.
        handle_ior_check_len(Ok(filled), buf.len())
    }

    if len <= 0 {
        return Ok(0);
    }

    let t0 = Instant::now();

    let ulen = len as usize;

    let mut input = vec![0u8; ulen + 4].into_boxed_slice();
    let mut occ = vec![0i32; ulen].into_boxed_slice();
    let mut c = [0i32; 256];

    read_fully(reader, &mut input)?;

    // SAFETY: `u8` has alignment 1 and every bit pattern is a valid `i32`, so viewing
    // the `i32` storage as bytes and overwriting it with file data is sound.
    let (_, occ_bytes, _) = unsafe { occ.align_to_mut::<u8>() };
    debug_assert_eq!(occ_bytes.len(), 4 * ulen);
    read_fully(reader_occ, occ_bytes)?;

    // SAFETY: same reasoning as for `occ` above.
    let (_, c_bytes, _) = unsafe { c.align_to_mut::<u8>() };
    debug_assert_eq!(c_bytes.len(), 4 * 256);
    read_fully(reader_c, c_bytes)?;

    // The last four bytes of the transformed stream hold the end marker row.
    let mut endmark_raw = [0u8; 4];
    endmark_raw.copy_from_slice(&input[ulen..]);
    let endmark_pos = i32::from_ne_bytes(endmark_raw);
    println!("[reverse-bwt] endmark_pos: {}", endmark_pos);

    if in_order {
        // The first recovered byte is the LAST byte of the content, so it belongs at
        // offset `len - 1`. (`len >= 1` here, so the subtraction cannot underflow.)
        handle_ior_nocheck(writer_out.seek(SeekFrom::Start((ulen - 1) as u64)))?;
    }

    let d1 = t0.elapsed();
    println!("[reverse-bwt] All reads takes: {:?}", d1);

    // Row 0 is the end marker's row; its last-column byte ends the original content.
    let mut onebyte_buf = [input[0]];
    handle_ior_check_len(writer_out.write(&onebyte_buf), 1)?;

    let mut l = 0i32;
    loop {
        // LF mapping: row (in the full matrix, end marker row included) whose first
        // column holds the very symbol stored at `input[l]`.
        let f = {
            let l_usz = l as usize;
            let sym_usz = input[l_usz] as usize;
            c[sym_usz] + occ[l_usz]
        };
        if f == endmark_pos {
            // That row is preceded by the end marker: the whole content is recovered.
            break
        }

        // `input` has no byte for the end marker row, so rows after it sit one slot
        // earlier in the stored stream.
        let pre_sym_idx = if f > endmark_pos { f - 1 } else { f };

        onebyte_buf[0] = input[pre_sym_idx as usize];
        if in_order {
            // Undo the advance of the previous write and step one byte further back.
            handle_ior_nocheck(writer_out.seek(SeekFrom::Current(-2)))?;
        }
        handle_ior_check_len(writer_out.write(&onebyte_buf), 1)?;

        l = pre_sym_idx;
    }

    let d2 = t0.elapsed();
    println!("[reverse-bwt] Main transform loop takes: {:?}", d2 - d1);

    Ok(len)
}
